{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "colab": {
   "name": "my_first_few_shot_classifier.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "accelerator": "GPU"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    },
    "id": "CDlQJEY27Krj"
   },
   "source": [
    "# Your own few-shot classification model ready in 15mn with PyTorch\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sicara/easy-few-shot-learning/blob/master/notebooks/my_first_few_shot_classifier.ipynb)\n",
    "\n",
    "I have been working on few-shot classification for a while now. The more I talk about it, the more the people around me seem to feel that it's some kind of dark magic. Even sadder: I noticed that very few actually used it on their projects. I think that's too bad, so I decided to make a tutorial so you'll have no excuse to deprive yourself of the power of few-shot learning methods.\n",
    "\n",
    "In 15 minutes and just a few lines of code, we are going to implement\n",
    "the [Prototypical Networks](https://arxiv.org/abs/1703.05175). It's the favorite method of\n",
    "many few-shot learning researchers (~2000 citations in 3 years), because 1) it works well,\n",
    "and 2) it's incredibly easy to grasp and to implement."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Discovering Prototypical Networks\n",
    "First, let's install the [tutorial GitHub repository](https://github.com/sicara/easy-few-shot-learning) and import some packages. If you're on Colab right now, you should also check that you're using a GPU (Edit > Notebook settings)."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "!pip install easyfsl"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import Omniglot\n",
    "from torchvision.models import resnet18\n",
    "from tqdm import tqdm\n",
    "\n",
    "from easyfsl.samplers import TaskSampler\n",
    "from easyfsl.utils import plot_images, sliding_average"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now, we need a dataset. I suggest we use [Omniglot](https://github.com/brendenlake/omniglot), a popular MNIST-like benchmark\n",
    "for few-shot classification. It contains 1623 characters from 50 different alphabets. Each character has been written by\n",
    "20 different people. \n",
    "\n",
    "Bonus: it's part of the `torchivision` package, so it's very easy to download\n",
    "and work with."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "image_size = 28\n",
    "\n",
    "# NB: background=True selects the train set, background=False selects the test set\n",
    "# It's the nomenclature from the original paper, we just have to deal with it\n",
    "\n",
    "train_set = Omniglot(\n",
    "    root=\"./data\",\n",
    "    background=True,\n",
    "    transform=transforms.Compose(\n",
    "        [\n",
    "            transforms.Grayscale(num_output_channels=3),\n",
    "            transforms.RandomResizedCrop(image_size),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "        ]\n",
    "    ),\n",
    "    download=True,\n",
    ")\n",
    "test_set = Omniglot(\n",
    "    root=\"./data\",\n",
    "    background=False,\n",
    "    transform=transforms.Compose(\n",
    "        [\n",
    "            # Omniglot images have 1 channel, but our model will expect 3-channel images\n",
    "            transforms.Grayscale(num_output_channels=3),\n",
    "            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),\n",
    "            transforms.CenterCrop(image_size),\n",
    "            transforms.ToTensor(),\n",
    "        ]\n",
    "    ),\n",
    "    download=True,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Let's take some time to grasp what few-shot classification is. Simply put, in a few-shot classification task, you have a labeled support set (which kind of acts\n",
    "like a catalog) and query set. For each image of the query set, we want to predict a label from the\n",
    "labels present in the support set. A few-shot classification model has to use the information from the\n",
    "support set in order to classify query images. We say *few-shot* when the support set contains very\n",
    "few images for each label (typically less than 10). The figure below shows a 3-way 2-shots classification task. \"3-way\" means \"3 different classes\" and \"2-shots\" means \"2 examples per class\".\n",
    "We expect a model that has never seen any Saint-Bernard, Pug or Labrador during its training to successfully\n",
    "predict the query labels. The support set is the only information that the model has regarding what a Saint-Bernard,\n",
    "a Pug or a Labrador can be.\n",
    "\n",
    "![few-shot classification task](https://images.ctfassets.net/be04ylp8y0qc/bZhboqYXfYeW4I88xmMNv/7c5efdc368206feaad045c674b1ced95/1_AteD0yXLkQ1BbjQTB3Ytwg.png?fm=webp)\n",
    "\n",
    "Most few-shot classification methods are *metric-based*. It works in two phases : 1) they use a CNN to project both\n",
    "support and query images into a feature space, and 2) they classify query images by comparing them to support images.\n",
    "If, in the feature space, an image is closer to pugs than it is to labradors and Saint-Bernards, we will guess that\n",
    "it's a pug.\n",
    "\n",
    "From there, we have two challenges :\n",
    "\n",
    "1. Find the good feature space. This is what convolutional networks are for. A CNN is basically a function that takes an image as input and outputs a representation (or *embedding*) of this image in a given feature space. The challenge here is to have a CNN that will\n",
    "project images of the same class into representations that are close to each other, even if it has not been trained\n",
    "on objects of this class.\n",
    "2. Find a good way to compare the representations in the feature space. This is the job of Prototypical Networks.\n",
    "\n",
    "\n",
    "![Prototypical classification](https://images.ctfassets.net/be04ylp8y0qc/45M9UcUp6KnzwDaBHeGZb7/bb2dcda5942ee7320600125ac2310af6/0_M0GSRZri859fGo48.png?fm=webp)\n",
    "\n",
    "From the support set, Prototypical Networks compute a prototype for each class, which is the mean of all embeddings\n",
    "of support images from this class. Then, each query is simply classified as the nearest prototype in the feature space,\n",
    "with respect to euclidean distance.\n",
    "\n",
    "If you want to learn more about how this works, I explain it\n",
    "[there](https://www.sicara.ai/blog/2019-07-30-image-classification-few-shot-meta-learning-5fd736a6c54d2).\n",
    "But now, let's get to coding.\n",
    "In the code below (modified from [this](https://github.com/sicara/easy-few-shot-learning/blob/master/easyfsl/methods/prototypical_networks.py)), we simply define Prototypical Networks as a torch module, with a `forward()` method.\n",
    "You may notice 2 things.\n",
    "\n",
    "1. We initiate `PrototypicalNetworks` with a *backbone*. This is the feature extractor we were talking about.\n",
    "Here, we use as backbone a ResNet18 pretrained on ImageNet, with its head chopped off and replaced by a `Flatten`\n",
    "layer. The output of the backbone, for an input image, will be a 512-dimensional feature vector.\n",
    "2. The forward method doesn't only take one input tensor, but 3: in order to predict the labels of query images,\n",
    "we also need support images and labels as inputs of the model."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class PrototypicalNetworks(nn.Module):\n",
    "    def __init__(self, backbone: nn.Module):\n",
    "        super(PrototypicalNetworks, self).__init__()\n",
    "        self.backbone = backbone\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        support_images: torch.Tensor,\n",
    "        support_labels: torch.Tensor,\n",
    "        query_images: torch.Tensor,\n",
    "    ) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Predict query labels using labeled support images.\n",
    "        \"\"\"\n",
    "        # Extract the features of support and query images\n",
    "        z_support = self.backbone.forward(support_images)\n",
    "        z_query = self.backbone.forward(query_images)\n",
    "\n",
    "        # Infer the number of different classes from the labels of the support set\n",
    "        n_way = len(torch.unique(support_labels))\n",
    "        # Prototype i is the mean of all instances of features corresponding to labels == i\n",
    "        z_proto = torch.cat(\n",
    "            [\n",
    "                z_support[torch.nonzero(support_labels == label)].mean(0)\n",
    "                for label in range(n_way)\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        # Compute the euclidean distance from queries to prototypes\n",
    "        dists = torch.cdist(z_query, z_proto)\n",
    "\n",
    "        # And here is the super complicated operation to transform those distances into classification scores!\n",
    "        scores = -dists\n",
    "        return scores\n",
    "\n",
    "\n",
    "convolutional_network = resnet18(pretrained=True)\n",
    "convolutional_network.fc = nn.Flatten()\n",
    "print(convolutional_network)\n",
    "\n",
    "model = PrototypicalNetworks(convolutional_network).cuda()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
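  {
   "cell_type": "markdown",
   "source": [
    "To make the prototype step concrete, here is a minimal sketch of the same logic in plain Python, without PyTorch or a backbone. The 2-D embeddings are made up for illustration, and `mean_vector` and `euclidean` are hypothetical helpers, not part of `easyfsl`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def mean_vector(vectors):\n",
    "    # Component-wise mean of a list of equal-length vectors\n",
    "    return [sum(component) / len(vectors) for component in zip(*vectors)]\n",
    "\n",
    "\n",
    "def euclidean(a, b):\n",
    "    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))\n",
    "\n",
    "\n",
    "# Toy 2-D embeddings (made up): 2 classes, 2 shots per class\n",
    "support_embeddings = [[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]]\n",
    "support_labels = [0, 0, 1, 1]\n",
    "query_embeddings = [[1.0, 1.0], [11.0, 9.0]]\n",
    "\n",
    "# Prototype i is the mean of the support embeddings with label i\n",
    "n_way = len(set(support_labels))\n",
    "prototypes = [\n",
    "    mean_vector(\n",
    "        [z for z, y in zip(support_embeddings, support_labels) if y == label]\n",
    "    )\n",
    "    for label in range(n_way)\n",
    "]\n",
    "\n",
    "# Score = negative distance, so the nearest prototype gets the highest score\n",
    "scores = [[-euclidean(q, p) for p in prototypes] for q in query_embeddings]\n",
    "predictions = [row.index(max(row)) for row in scores]\n",
    "\n",
    "print(prototypes)  # [[1.0, 0.0], [11.0, 10.0]]\n",
    "print(predictions)  # [0, 1]\n",
    "```\n",
    "\n",
    "With a single shot per class, each prototype is simply the lone support embedding of that class, so the method reduces to a nearest-neighbour search."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },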
  {
   "cell_type": "markdown",
   "source": [
    "Now we have a model! Note that we used a pretrained feature extractor,\n",
    "so our model should already be up and running. Let's see that.\n",
    "\n",
    "Here we create a dataloader that will feed few-shot classification tasks to our model.\n",
    "But a regular PyTorch dataloader will feed batches of images, with no consideration for\n",
    "their label or whether they are support or query. We need 2 specific features in our case.\n",
    "\n",
    "1. We need images evenly distributed between a given number of classes.\n",
    "2. We need them split between support and query sets.\n",
    "\n",
    "For the first point, I wrote a custom sampler: it first samples `n_way` classes from the dataset,\n",
    "then it samples `n_shot + n_query` images for each class (for a total of `n_way * (n_shot + n_query)`\n",
    "images in each batch).\n",
    "For the second point, I have a custom collate function to replace the built-in PyTorch `collate_fn`.\n",
    "This baby feed each batch as the combination of 5 items:\n",
    "\n",
    "1. support images\n",
    "2. support labels between 0 and `n_way`\n",
    "3. query images\n",
    "4. query labels between 0 and `n_way`\n",
    "5. a mapping of each label in `range(n_way)` to its true class id in the dataset\n",
    "(it's not used by the model but it's very useful for us to know what the true class is)\n",
    "\n",
    "You can see that in PyTorch, a DataLoader is basically the combination of a sampler, a dataset and a collate function\n",
    "(and some multiprocessing voodoo): sampler says which items to fetch, the dataset says how to fetch them, and\n",
    "the collate function says how to present these items together. If you want to dive into these custom objects,\n",
    "they're [here](https://github.com/sicara/easy-few-shot-learning/tree/master/easyfsl/data_tools)."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "N_WAY = 5  # Number of classes in a task\n",
    "N_SHOT = 5  # Number of images per class in the support set\n",
    "N_QUERY = 10  # Number of images per class in the query set\n",
    "N_EVALUATION_TASKS = 100\n",
    "\n",
    "# The sampler needs a dataset with a \"get_labels\" method. Check the code if you have any doubt!\n",
    "test_set.get_labels = lambda: [\n",
    "    instance[1] for instance in test_set._flat_character_images\n",
    "]\n",
    "test_sampler = TaskSampler(\n",
    "    test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS\n",
    ")\n",
    "\n",
    "test_loader = DataLoader(\n",
    "    test_set,\n",
    "    batch_sampler=test_sampler,\n",
    "    num_workers=12,\n",
    "    pin_memory=True,\n",
    "    collate_fn=test_sampler.episodic_collate_fn,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
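  {
   "cell_type": "markdown",
   "source": [
    "The sampling logic described above can be sketched in plain Python. This is not the `TaskSampler` implementation from `easyfsl`: `sample_task` and the toy label list are hypothetical, and the sketch works on dataset indices instead of images.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def sample_task(labels, n_way, n_shot, n_query, rng):\n",
    "    # Group dataset indices by class\n",
    "    indices_by_class = {}\n",
    "    for index, label in enumerate(labels):\n",
    "        indices_by_class.setdefault(label, []).append(index)\n",
    "\n",
    "    # 1) pick n_way classes, 2) pick n_shot + n_query items for each of them\n",
    "    class_ids = rng.sample(sorted(indices_by_class), n_way)\n",
    "    support, query = [], []\n",
    "    for task_label, class_id in enumerate(class_ids):\n",
    "        picked = rng.sample(indices_by_class[class_id], n_shot + n_query)\n",
    "        # Inside a task, labels are remapped to 0 .. n_way - 1\n",
    "        support += [(index, task_label) for index in picked[:n_shot]]\n",
    "        query += [(index, task_label) for index in picked[n_shot:]]\n",
    "\n",
    "    # class_ids[task_label] gives back the true class id in the dataset\n",
    "    return support, query, class_ids\n",
    "\n",
    "\n",
    "# Hypothetical toy dataset: 10 classes with 20 items each\n",
    "toy_labels = [class_id for class_id in range(10) for _ in range(20)]\n",
    "support, query, class_ids = sample_task(\n",
    "    toy_labels, n_way=5, n_shot=5, n_query=10, rng=random.Random(0)\n",
    ")\n",
    "print(len(support), len(query))  # 25 50\n",
    "```\n",
    "\n",
    "Each task therefore contains `n_way * (n_shot + n_query)` items, and no item appears in both the support set and the query set."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },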
  {
   "cell_type": "markdown",
   "source": [
    "We created a dataloader that will feed us with 5-way 5-shot tasks (the most common setting in the litterature).\n",
    "Now, as every data scientist should do before launching opaque training scripts,\n",
    "let's take a look at our dataset."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "(\n",
    "    example_support_images,\n",
    "    example_support_labels,\n",
    "    example_query_images,\n",
    "    example_query_labels,\n",
    "    example_class_ids,\n",
    ") = next(iter(test_loader))\n",
    "\n",
    "plot_images(example_support_images, \"support images\", images_per_row=N_SHOT)\n",
    "plot_images(example_query_images, \"query images\", images_per_row=N_QUERY)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "For both support and query set, you should have one line for each class.\n",
    "\n",
    "How does our model perform on this task?"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model.eval()\n",
    "example_scores = model(\n",
    "    example_support_images.cuda(),\n",
    "    example_support_labels.cuda(),\n",
    "    example_query_images.cuda(),\n",
    ").detach()\n",
    "\n",
    "_, example_predicted_labels = torch.max(example_scores.data, 1)\n",
    "\n",
    "print(\"Ground Truth / Predicted\")\n",
    "for i in range(len(example_query_labels)):\n",
    "    print(\n",
    "        f\"{test_set._characters[example_class_ids[example_query_labels[i]]]} / {test_set._characters[example_class_ids[example_predicted_labels[i]]]}\"\n",
    "    )"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "This doesn't look bad: keep in mind that the model was trained on very different images, and has only seen 5 examples for each class!\n",
    "\n",
    "Now that we have a first idea, let's see more precisely how good our model is."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def evaluate_on_one_task(\n",
    "    support_images: torch.Tensor,\n",
    "    support_labels: torch.Tensor,\n",
    "    query_images: torch.Tensor,\n",
    "    query_labels: torch.Tensor,\n",
    ") -> [int, int]:\n",
    "    \"\"\"\n",
    "    Returns the number of correct predictions of query labels, and the total number of predictions.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        torch.max(\n",
    "            model(support_images.cuda(), support_labels.cuda(), query_images.cuda())\n",
    "            .detach()\n",
    "            .data,\n",
    "            1,\n",
    "        )[1]\n",
    "        == query_labels.cuda()\n",
    "    ).sum().item(), len(query_labels)\n",
    "\n",
    "\n",
    "def evaluate(data_loader: DataLoader):\n",
    "    # We'll count everything and compute the ratio at the end\n",
    "    total_predictions = 0\n",
    "    correct_predictions = 0\n",
    "\n",
    "    # eval mode affects the behaviour of some layers (such as batch normalization or dropout)\n",
    "    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way)\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for episode_index, (\n",
    "            support_images,\n",
    "            support_labels,\n",
    "            query_images,\n",
    "            query_labels,\n",
    "            class_ids,\n",
    "        ) in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
    "\n",
    "            correct, total = evaluate_on_one_task(\n",
    "                support_images, support_labels, query_images, query_labels\n",
    "            )\n",
    "\n",
    "            total_predictions += total\n",
    "            correct_predictions += correct\n",
    "\n",
    "    print(\n",
    "        f\"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%\"\n",
    "    )\n",
    "\n",
    "\n",
    "evaluate(test_loader)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "With absolutely zero training on Omniglot images, and only 5 examples per class, we achieve around 86% accuracy! Isn't this a great start?\n",
    "\n",
    "Now that you know how to make Prototypical Networks work, you can see what happens if you tweak it\n",
    "a little bit (change the backbone, use other distances than euclidean...) or if you change the problem\n",
    "(more classes in each task, less or more examples in the support set, maybe even one example only,\n",
    "but keep in mind that in that case Prototypical Networks are just standard nearest neighbour).\n",
    "\n",
    "When you're done, you can scroll further down and learn how to **meta-train this model**, to get even better results."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training a meta-learning algorithm\n",
    "\n",
    "Let's use the \"background\" images of Omniglot as training set. Here we prepare a data loader of 40 000 few-shot classification\n",
    "tasks on which we will train our model. The alphabets used in the training set are entirely separated from those used in the testing set.\n",
    "This guarantees that at test time, the model will have to classify characters that were not seen during training.\n",
    "\n",
    "Note that we don't set a validation set here to keep this notebook concise,\n",
    "but keep in mind that **this is not good practice** and you should always use validation when training a model for production."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "N_TRAINING_EPISODES = 40000\n",
    "N_VALIDATION_TASKS = 100\n",
    "\n",
    "train_set.get_labels = lambda: [instance[1] for instance in train_set._flat_character_images]\n",
    "train_sampler = TaskSampler(\n",
    "    train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES\n",
    ")\n",
    "train_loader = DataLoader(\n",
    "    train_set,\n",
    "    batch_sampler=train_sampler,\n",
    "    num_workers=12,\n",
    "    pin_memory=True,\n",
    "    collate_fn=train_sampler.episodic_collate_fn,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We will keep the same model. So our weights will be pre-trained on ImageNet. If you want to start a training from scratch,\n",
    "feel free to set `pretrained=False` in the definition of the ResNet.\n",
    "\n",
    "Here we define our loss and our optimizer (cross entropy and Adam, pretty standard), and a `fit` method.\n",
    "This method takes a classification task as input (support set and query set). It predicts the labels of the query set\n",
    "based on the information from the support set; then it compares the predicted labels to ground truth query labels,\n",
    "and this gives us a loss value. Then it uses this loss to update the parameters of the model. This is a *meta-training loop*."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "\n",
    "def fit(\n",
    "    support_images: torch.Tensor,\n",
    "    support_labels: torch.Tensor,\n",
    "    query_images: torch.Tensor,\n",
    "    query_labels: torch.Tensor,\n",
    ") -> float:\n",
    "    optimizer.zero_grad()\n",
    "    classification_scores = model(\n",
    "        support_images.cuda(), support_labels.cuda(), query_images.cuda()\n",
    "    )\n",
    "\n",
    "    loss = criterion(classification_scores, query_labels.cuda())\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    return loss.item()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
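  {
   "cell_type": "markdown",
   "source": [
    "Why does cross entropy make sense on negative distances? `nn.CrossEntropyLoss` applies a log-softmax to the scores, so the loss is small when the prototype of the true class is much closer to the query than the other prototypes. Here is a minimal worked example in plain Python for a single query, with made-up distances and a hypothetical `cross_entropy` helper.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def cross_entropy(scores, true_label):\n",
    "    # Negative log-softmax of the true class, with max-subtraction for stability\n",
    "    highest = max(scores)\n",
    "    log_normalizer = highest + math.log(sum(math.exp(s - highest) for s in scores))\n",
    "    return log_normalizer - scores[true_label]\n",
    "\n",
    "\n",
    "# Made-up distances from one query to 3 prototypes\n",
    "distances = [1.0, 4.0, 6.0]\n",
    "scores = [-d for d in distances]\n",
    "\n",
    "# Loss when the nearest prototype (index 0) is the true class\n",
    "loss_if_nearest_is_right = cross_entropy(scores, true_label=0)\n",
    "# Loss when the farthest prototype (index 2) is the true class\n",
    "loss_if_nearest_is_wrong = cross_entropy(scores, true_label=2)\n",
    "\n",
    "print(loss_if_nearest_is_right < loss_if_nearest_is_wrong)  # True\n",
    "```\n",
    "\n",
    "The gap between the two losses equals the gap between the two distances (here 6.0 - 1.0), so minimizing the loss pulls each query toward the prototype of its own class and away from the others."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },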
  {
   "cell_type": "markdown",
   "source": [
    "To train the model, we are just going to iterate over a large number of randomly generated few-shot classification tasks,\n",
    "and let the `fit` method update our model after each task. This is called **episodic training**.\n",
    "\n",
    "This took me 20mn on an RTX 2080 and I promised you that this whole tutorial would take 15mn.\n",
    "So if you don't want to run the training yourself, you can just skip the training and load the model that I trained\n",
    "using the exact same code."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Train the model yourself with this cell\n",
    "\n",
    "log_update_frequency = 10\n",
    "\n",
    "all_loss = []\n",
    "model.train()\n",
    "with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:\n",
    "    for episode_index, (\n",
    "        support_images,\n",
    "        support_labels,\n",
    "        query_images,\n",
    "        query_labels,\n",
    "        _,\n",
    "    ) in tqdm_train:\n",
    "        loss_value = fit(support_images, support_labels, query_images, query_labels)\n",
    "        all_loss.append(loss_value)\n",
    "\n",
    "        if episode_index % log_update_frequency == 0:\n",
    "            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Or just load mine\n",
    "\n",
    "!wget https://public-sicara.s3.eu-central-1.amazonaws.com/easy-fsl/resnet18_with_pretraining.tar\n",
    "model.load_state_dict(torch.load(\"resnet18_with_pretraining.tar\", map_location=\"cuda\"))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now let's see if our model got better!"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "evaluate(test_loader)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Around 98%!\n",
    "\n",
    "It's not surprising that the model performs better after being further trained on Omniglot images than it was with its\n",
    "ImageNet-based parameters. However, we have to keep in mind that the classes on which we just evaluated our model were still\n",
    "**not seen during training**, so 99% (with a 12% improvement over the model trained on ImageNet) seems like a decent performance.\n",
    "\n",
    "## What have we learned?\n",
    "\n",
    "- What a Prototypical Network is and how to implement one in 15 lines of code.\n",
    "- How to use Omniglot to evaluate few-shot models\n",
    "- How to use custom PyTorch objets to sample batches in the shape of a few-shot classification tasks.\n",
    "- How to use meta-learning to train a few-shot algorithm.\n",
    "\n",
    "## What's next?\n",
    "\n",
    "- Take this notebook in your own hands, tweak everything that there is to tweak. It's the best way to understand what does what.\n",
    "- Implement other few-shot learning methods, such as Matching Networks, Relation Networks, MAML...\n",
    "- Try other ways of training. Episodic training is not the only way to train a model to generalize to new classes!\n",
    "- Experiment on other, more challenging few-shot learning benchmarks, such as [CUB](http://www.vision.caltech.edu/visipedia/CUB-200.html)\n",
    "or [Meta-Dataset](https://github.com/google-research/meta-dataset).\n",
    "- If you liked this tutorial, feel free to ⭐ [give us a star on Github](https://github.com/sicara/easy-few-shot-learning) ⭐\n",
    "- **Contribute!** The companion repository of this notebook is meant to become a boilerplate, a source of useful code that\n",
    "that newcomers can use to start their few-shot learning projects.\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
 ]
}